package com.mozilla.secops.postprocessing;

import com.mozilla.secops.DocumentingTransform;
import com.mozilla.secops.IOOptions;
import com.mozilla.secops.OutputOptions;
import com.mozilla.secops.Watchlist;
import com.mozilla.secops.alert.Alert;
import com.mozilla.secops.alert.AlertFormatter;
import com.mozilla.secops.alert.AlertMeta;
import com.mozilla.secops.input.Input;
import com.mozilla.secops.metrics.CfgTickBuilder;
import com.mozilla.secops.metrics.CfgTickProcessor;
import com.mozilla.secops.parser.Event;
import com.mozilla.secops.parser.EventFilter;
import com.mozilla.secops.parser.EventFilterRule;
import com.mozilla.secops.parser.ParserCfg;
import com.mozilla.secops.parser.ParserDoFn;
import com.mozilla.secops.parser.Payload;
import com.mozilla.secops.state.StateException;
import com.mozilla.secops.window.GlobalTriggers;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Objects;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.metrics.Distribution;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.options.Default;
import org.apache.beam.sdk.options.Description;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.state.Timer;
import org.apache.beam.sdk.state.TimerSpec;
import org.apache.beam.sdk.state.TimerSpecs;
import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Flatten;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionList;
import org.joda.time.Duration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** {@link PostProcessing} implements analysis of alerts generated by other pipelines. */
public class PostProcessing implements Serializable {
  private static final long serialVersionUID = 1L;

  /** Namespace under which PostProcessing metrics (e.g. the watchlist distribution) are registered */
  public static final String METRICS_NAMESPACE = "postprocessing";

  /** Name of the {@link Distribution} metric for watchlist alert processing time */
  public static final String WATCHLIST_ALERT_PROCESSING_TIME_METRIC = "alert_processing_time";

  /**
   * Parse incoming events and filter to only include events of type {@link
   * com.mozilla.secops.parser.Alert}
   *
   * <p>Configuration tick events are also allowed through the filter (see {@code
   * EventFilter#passConfigurationTicks}).
   */
  public static class Parse extends PTransform<PCollection<String>, PCollection<Event>> {
    private static final long serialVersionUID = 1L;

    private final ParserCfg cfg;

    /**
     * Create a new {@link Parse} transform using the specified pipeline options
     *
     * @param options Pipeline options used to build the parser configuration
     */
    public Parse(PostProcessingOptions options) {
      cfg = ParserCfg.fromInputOptions(options);
    }

    @Override
    public PCollection<Event> expand(PCollection<String> col) {
      // Keep only alert payloads, while still letting configuration ticks pass.
      EventFilter filter = new EventFilter().passConfigurationTicks();
      filter.addRule(new EventFilterRule().wantSubtype(Payload.PayloadType.ALERT));

      return col.apply(
          ParDo.of(new ParserDoFn().withConfiguration(cfg).withInlineEventFilter(filter)));
    }
  }

  /**
   * Check incoming alert events against a watchlist of various identifiers.
   *
   * <p>Uses {@link Watchlist} to retrieve the watchlist from datastore and check these entries
   * against alert metadata keys.
   *
   * <p>Since state based DoFn requires keyed input, we just expect a dummy key here to satisfy that
   * requirement. Elements passed to the function can all be keyed with the same value.
   *
   * <p>Alerts are buffered in state and checked against the watchlist in batches. A batch is
   * flushed when it reaches {@code MAX_BATCH_SIZE} elements, when {@code MAX_BATCH_DURATION} of
   * processing time has elapsed since the first element of the batch was buffered, or when the
   * event-time expiry timer fires at the end of the window.
   */
  public static class WatchlistAnalyze extends DoFn<KV<Boolean, Alert>, Alert>
      implements DocumentingTransform {
    private static final long serialVersionUID = 1L;

    private static final int MAX_BATCH_SIZE = 250;
    private static final Duration MAX_BATCH_DURATION = Duration.standardSeconds(1);

    /** Alert metadata keys whose values are checked against email watchlist entries. */
    private static final AlertMeta.Key[] EMAIL_KEYS =
        new AlertMeta.Key[] {
          AlertMeta.Key.EMAIL, AlertMeta.Key.USERNAME, AlertMeta.Key.IDENTITY_KEY
        };

    /** Alert metadata keys whose values are checked against IP watchlist entries. */
    private static final AlertMeta.Key[] IP_KEYS =
        new AlertMeta.Key[] {AlertMeta.Key.SOURCEADDRESS, AlertMeta.Key.SOURCEADDRESS_PREVIOUS};

    // Worker-local resources created in setup(); transient so they are never serialized with
    // the DoFn.
    private transient Logger log;
    private transient Watchlist wl;

    private final String warningEmail;
    private final String criticalEmail;

    // Allocate an expiry timer, which will be used when we reach the end of window and is
    // independent of element count or processing time.
    @TimerId("alertExpiry")
    private final TimerSpec alertExpiry = TimerSpecs.timer(TimeDomain.EVENT_TIME);

    // Allocate a stale timer, which will be used after a certain period of processing time
    // elapses to dequeue any buffered elements.
    @TimerId("alertStale")
    private final TimerSpec alertStale = TimerSpecs.timer(TimeDomain.PROCESSING_TIME);

    @StateId("alertBuffer")
    private final StateSpec<BagState<Alert>> alertBuffer = StateSpecs.bag();

    @StateId("alertBufferCount")
    private final StateSpec<ValueState<Integer>> alertBufferCount = StateSpecs.value();

    /** Time spent in {@link #processElement}, in nanoseconds, including any batch flush. */
    private final Distribution alertProcessingTime;

    /** A single metadata value extracted from an alert, tagged with its key and watchlist kind. */
    private static class KeyData {
      public final String key;
      public final String value;
      public final String type;

      @Override
      public boolean equals(Object o) {
        if (o == this) {
          return true;
        }
        if (!(o instanceof KeyData)) {
          return false;
        }
        KeyData k = (KeyData) o;
        return key.equals(k.key) && value.equals(k.value) && type.equals(k.type);
      }

      @Override
      public int hashCode() {
        return Objects.hash(key, value, type);
      }

      KeyData(String key, String value, String type) {
        this.key = key;
        this.value = value;
        this.type = type;
      }
    }

    /**
     * Initialize WatchlistAnalyze with {@link PostProcessingOptions}
     *
     * @param options {@link PostProcessingOptions}
     */
    public WatchlistAnalyze(PostProcessingOptions options) {
      warningEmail = options.getWarningSeverityEmail();
      criticalEmail = options.getCriticalSeverityEmail();
      alertProcessingTime =
          Metrics.distribution(METRICS_NAMESPACE, WATCHLIST_ALERT_PROCESSING_TIME_METRIC);
    }

    /** {@inheritDoc} */
    public String getTransformDoc() {
      return "Alert on matched watchlist entries in incoming alerts from other pipelines.";
    }

    @Setup
    public void setup() throws IOException {
      log = LoggerFactory.getLogger(WatchlistAnalyze.class);
      try {
        wl = new Watchlist();
      } catch (StateException exc) {
        // Keep the original exception as the cause so the failure is diagnosable.
        throw new RuntimeException(exc.getMessage(), exc);
      }
    }

    @Teardown
    public void teardown() throws IOException {
      // setup() may have failed before the watchlist was created.
      if (wl != null) {
        wl.done();
        wl = null;
      }
    }

    @ProcessElement
    public void processElement(
        ProcessContext c,
        BoundedWindow w,
        @StateId("alertBuffer") BagState<Alert> alertBuffer,
        @StateId("alertBufferCount") ValueState<Integer> alertBufferCount,
        @TimerId("alertExpiry") Timer alertExpiry,
        @TimerId("alertStale") Timer alertStale) {
      long startTime = System.nanoTime();

      alertExpiry.set(w.maxTimestamp());

      Alert sourceAlert = c.element().getValue();

      Integer cnt = alertBufferCount.read();
      if (cnt == null) {
        cnt = 0;
      }
      if (cnt == 0) {
        // First alert of a new batch; bound how long it may sit in the buffer.
        alertStale.offset(MAX_BATCH_DURATION).setRelative();
      }
      cnt++;
      alertBufferCount.write(cnt);
      alertBuffer.add(sourceAlert);

      if (cnt >= MAX_BATCH_SIZE) {
        flushBuffer(c, alertBuffer, alertBufferCount);
      }

      alertProcessingTime.update(System.nanoTime() - startTime);
    }

    @OnTimer("alertStale")
    public void onStale(
        OnTimerContext c,
        @StateId("alertBuffer") BagState<Alert> alertBuffer,
        @StateId("alertBufferCount") ValueState<Integer> alertBufferCount) {
      flushBuffer(c, alertBuffer, alertBufferCount);
    }

    @OnTimer("alertExpiry")
    public void onExpiry(
        OnTimerContext c,
        @StateId("alertBuffer") BagState<Alert> alertBuffer,
        @StateId("alertBufferCount") ValueState<Integer> alertBufferCount) {
      flushBuffer(c, alertBuffer, alertBufferCount);
    }

    /**
     * Analyze any buffered alerts, emit the resulting watchlist alerts, and reset buffer state.
     *
     * @param c Context used to emit output alerts
     * @param alertBuffer Buffered input alerts
     * @param alertBufferCount Number of buffered input alerts
     */
    private void flushBuffer(
        WindowedContext c, BagState<Alert> alertBuffer, ValueState<Integer> alertBufferCount) {
      if (!alertBuffer.isEmpty().read()) {
        for (Alert a : processAlerts(alertBuffer.read())) {
          c.output(a);
        }
      }
      alertBuffer.clear();
      alertBufferCount.clear();
    }

    private ArrayList<KeyData> extractIpValues(Alert a) {
      ArrayList<KeyData> ret = new ArrayList<>();
      for (AlertMeta.Key i : IP_KEYS) {
        String v = a.getMetadataValue(i);
        if (v != null) {
          ret.add(new KeyData(i.getKey(), v, Watchlist.watchlistIpKind));
        }
      }
      return ret;
    }

    private ArrayList<KeyData> extractEmailValues(Alert a) {
      ArrayList<KeyData> ret = new ArrayList<>();
      for (AlertMeta.Key i : EMAIL_KEYS) {
        String v = a.getMetadataValue(i);
        if (v != null) {
          // Some keys we may check are multi valued. Attempt to split the list; if it is
          // not a list value, it must be a single value and can be added directly.
          try {
            AlertMeta.splitListValues(i, v)
                .forEach(
                    email -> ret.add(new KeyData(i.getKey(), email, Watchlist.watchlistEmailKind)));
          } catch (IOException e) {
            ret.add(new KeyData(i.getKey(), v, Watchlist.watchlistEmailKind));
          }
        }
      }
      return ret;
    }

    /**
     * Check a batch of alerts against the watchlist.
     *
     * <p>Distinct email and IP values across the whole batch are collected first, so the watchlist
     * is queried once per kind rather than once per alert.
     *
     * @param input Buffered alerts
     * @return New watchlist alerts, one per matched metadata value
     */
    private ArrayList<Alert> processAlerts(Iterable<Alert> input) {
      // Copy the alerts out of the iterable, since they are iterated more than once, and collect
      // the distinct values of each kind that need to be looked up.
      ArrayList<Alert> alerts = new ArrayList<>();
      ArrayList<String> emailValues = new ArrayList<>();
      ArrayList<String> ipValues = new ArrayList<>();
      for (Alert a : input) {
        alerts.add(a);

        for (KeyData i : extractIpValues(a)) {
          if (!ipValues.contains(i.value)) {
            ipValues.add(i.value);
          }
        }
        for (KeyData i : extractEmailValues(a)) {
          if (!emailValues.contains(i.value)) {
            emailValues.add(i.value);
          }
        }
      }
      log.info("processing {} alerts", alerts.size());

      // Fetch only the watchlist entries matching the collected values.
      ArrayList<Watchlist.WatchlistEntry> emailEntries =
          wl.getWatchlistEntries(Watchlist.watchlistEmailKind, emailValues);
      ArrayList<Watchlist.WatchlistEntry> ipEntries =
          wl.getWatchlistEntries(Watchlist.watchlistIpKind, ipValues);

      ArrayList<Alert> ret = new ArrayList<>();
      for (Alert a : alerts) {
        for (KeyData i : extractEmailValues(a)) {
          Watchlist.WatchlistEntry matchedEntry = evaluateKeyData(i, emailEntries);
          if (matchedEntry != null) {
            ret.add(createAlert(a, matchedEntry, i));
          }
        }
        for (KeyData i : extractIpValues(a)) {
          Watchlist.WatchlistEntry matchedEntry = evaluateKeyData(i, ipEntries);
          if (matchedEntry != null) {
            ret.add(createAlert(a, matchedEntry, i));
          }
        }
      }
      return ret;
    }

    /** Return the first entry whose object equals the key's value, or {@code null} if none. */
    private Watchlist.WatchlistEntry evaluateKeyData(
        KeyData k, ArrayList<Watchlist.WatchlistEntry> entries) {
      for (Watchlist.WatchlistEntry w : entries) {
        if (k.value.equals(w.getObject())) {
          return w;
        }
      }
      return null;
    }

    /** Build the watchlist alert for a single matched value found in {@code sourceAlert}. */
    private Alert createAlert(Alert sourceAlert, Watchlist.WatchlistEntry entry, KeyData k) {
      Alert a = new Alert();
      a.setCategory("postprocessing");
      a.setSubcategory("watchlist");
      a.setSummary(
          String.format("matched watchlist object found in alert %s", sourceAlert.getAlertId()));
      a.setSeverity(entry.getSeverity());

      // Add escalation metadata, but only when a recipient was actually configured.
      if (entry.getSeverity() == Alert.AlertSeverity.WARNING && warningEmail != null) {
        a.addMetadata(AlertMeta.Key.NOTIFY_EMAIL_DIRECT, warningEmail);
      }
      if (entry.getSeverity() == Alert.AlertSeverity.CRITICAL && criticalEmail != null) {
        a.addMetadata(AlertMeta.Key.NOTIFY_EMAIL_DIRECT, criticalEmail);
      }

      a.addMetadata(AlertMeta.Key.SOURCE_ALERT, sourceAlert.getAlertId().toString());
      a.addMetadata(AlertMeta.Key.MATCHED_METADATA_KEY, k.key);
      // This may seem redundant with the below `matched_object`, but trying to
      // future proof against adding regex matchers (or similar).
      a.addMetadata(AlertMeta.Key.MATCHED_METADATA_VALUE, k.value);
      a.addMetadata(AlertMeta.Key.MATCHED_TYPE, entry.getType());
      a.addMetadata(AlertMeta.Key.MATCHED_OBJECT, entry.getObject());
      a.addMetadata(AlertMeta.Key.WATCHLIST_CREATED_BY, entry.getCreatedBy());
      return a;
    }
  }

  /** Runtime options for {@link PostProcessing} pipeline. */
  public interface PostProcessingOptions extends PipelineOptions, IOOptions {
    // Toggles the watchlist analysis branch (WatchlistAnalyze); enabled by default.
    @Description("Enable watchlist analysis")
    @Default.Boolean(true)
    Boolean getEnableWatchlistAnalysis();

    void setEnableWatchlistAnalysis(Boolean value);

    // Recipient used as NOTIFY_EMAIL_DIRECT metadata on WARNING severity watchlist alerts.
    // No @Default is declared for this option.
    @Description("Email address to send warning level alerts to")
    String getWarningSeverityEmail();

    void setWarningSeverityEmail(String value);

    // Recipient used as NOTIFY_EMAIL_DIRECT metadata on CRITICAL severity watchlist alerts.
    // No @Default is declared for this option.
    @Description("Email address to send critical level alerts to")
    String getCriticalSeverityEmail();

    void setCriticalSeverityEmail(String value);

    // Toggles the alert summary analysis branch (AlertSummary); disabled by default.
    @Description("Enable alert summary analysis")
    @Default.Boolean(false)
    Boolean getEnableAlertSummaryAnalysis();

    void setEnableAlertSummaryAnalysis(Boolean value);

    // NOTE(review): presumably parsed by AlertSummary; the expected string format is not
    // visible in this file -- confirm there.
    @Description("Thresholds to use for alert summary analysis")
    String[] getAlertSummaryAnalysisThresholds();

    void setAlertSummaryAnalysisThresholds(String[] value);
  }

  /**
   * Build a configuration tick describing this PostProcessing pipeline.
   *
   * <p>The tick includes the pipeline options plus the transform documentation of every analysis
   * that is enabled in {@code options}.
   *
   * @param options Pipeline options
   * @return Serialized configuration tick
   * @throws IOException if the tick cannot be built
   */
  public static String buildConfigurationTick(PostProcessingOptions options) throws IOException {
    final boolean watchlistEnabled = options.getEnableWatchlistAnalysis();
    final boolean summaryEnabled = options.getEnableAlertSummaryAnalysis();

    CfgTickBuilder tickBuilder = new CfgTickBuilder().includePipelineOptions(options);
    if (watchlistEnabled) {
      tickBuilder.withTransformDoc(new WatchlistAnalyze(options));
    }
    if (summaryEnabled) {
      tickBuilder.withTransformDoc(new AlertSummary(options));
    }
    return tickBuilder.build();
  }

  /**
   * Process input collection
   *
   * <p>Process collection of input events, returning a collection of alerts as required. Raw
   * input is parsed, alert payloads are unwrapped, and each enabled analysis runs over them. When
   * configuration ticks are enabled, parsed events are also fed to a {@link CfgTickProcessor}. The
   * outputs of all enabled branches are flattened into a single collection.
   *
   * @param input Input collection
   * @param options Pipeline options
   * @return Output collection
   */
  public static PCollection<Alert> processInput(
      PCollection<String> input, PostProcessingOptions options) {
    PCollectionList<Alert> alertList = PCollectionList.empty(input.getPipeline());

    PCollection<Event> inputEvents = input.apply("parse", new Parse(options));

    // Unwrap alert payloads; any other event type (e.g. configuration ticks) is skipped here.
    PCollection<Alert> inputAlerts =
        inputEvents.apply(
            "extract alerts",
            ParDo.of(
                new DoFn<Event, Alert>() {
                  private static final long serialVersionUID = 1L;

                  @ProcessElement
                  public void processElement(ProcessContext c) {
                    Event e = c.element();
                    if (e.getPayloadType() != Payload.PayloadType.ALERT) {
                      return;
                    }
                    com.mozilla.secops.parser.Alert ae = e.getPayload();
                    c.output(ae.getAlert());
                  }
                }));

    if (options.getEnableWatchlistAnalysis()) {
      // WatchlistAnalyze is a stateful DoFn and needs keyed input, so every alert gets the same
      // dummy key.
      alertList =
          alertList.and(
              inputAlerts
                  .apply(
                      "watchlist key for state",
                      ParDo.of(
                          new DoFn<Alert, KV<Boolean, Alert>>() {
                            private static final long serialVersionUID = 1L;

                            @ProcessElement
                            public void processElement(ProcessContext c) {
                              c.output(KV.of(true, c.element()));
                            }
                          }))
                  .apply("watchlist analyze", ParDo.of(new WatchlistAnalyze(options)))
                  .apply("watchlist analyze rewindow for output", new GlobalTriggers<Alert>(5)));
    }
    if (options.getEnableAlertSummaryAnalysis()) {
      alertList =
          alertList.and(inputAlerts.apply("alert summary analysis", new AlertSummary(options)));
    }

    // If configuration ticks were enabled, enable the processor here too
    if (options.getGenerateConfigurationTicksInterval() > 0) {
      alertList =
          alertList.and(
              inputEvents
                  .apply(
                      "cfgtick processor", ParDo.of(new CfgTickProcessor("postprocessing-cfgtick")))
                  .apply(new GlobalTriggers<Alert>(5)));
    }

    return alertList.apply("flatten output", Flatten.<Alert>pCollections());
  }

  /**
   * Build the PostProcessing pipeline from {@code options} and run it.
   *
   * <p>Input is read through the composite input adapter (seeded with this pipeline's
   * configuration tick), processed by {@link #processInput}, formatted and converted to strings,
   * then written through the composite output.
   *
   * @param options Pipeline options
   */
  private static void runPostProcessing(PostProcessingOptions options)
      throws IllegalArgumentException {
    Pipeline p = Pipeline.create(options);

    PCollection<String> input;
    try {
      input =
          p.apply("input", Input.compositeInputAdapter(options, buildConfigurationTick(options)));
    } catch (IOException exc) {
      // Keep the original exception as the cause so its stack trace is not lost.
      throw new RuntimeException(exc.getMessage(), exc);
    }
    processInput(input, options)
        .apply("output format", ParDo.of(new AlertFormatter(options)))
        .apply("output convert", MapElements.via(new AlertFormatter.AlertToString()))
        .apply("output", OutputOptions.compositeOutput(options));

    p.run();
  }

  /**
   * Entry point for Beam pipeline.
   *
   * <p>Registers {@link PostProcessingOptions}, parses and validates the command line into it, and
   * runs the pipeline.
   *
   * @param args Runtime arguments.
   * @throws Exception Exception
   */
  public static void main(String[] args) throws Exception {
    PipelineOptionsFactory.register(PostProcessingOptions.class);
    runPostProcessing(
        PipelineOptionsFactory.fromArgs(args).withValidation().as(PostProcessingOptions.class));
  }
}
